import json

from bedrock.client import BedrockClient
from botocore.exceptions import ClientError

from typing import Optional

import logging

import os
from dotenv import load_dotenv

# Pull variables from a local .env file (if one exists) into os.environ.
# This runs before the class below is defined, so the os.getenv(...) default
# arguments of its constructor can see values coming from .env.
load_dotenv()

# Configure the root logger at INFO level and create this module's logger.
# NOTE(review): configuring the root logger at import time affects every
# application that imports this module; consider moving this to the entry point.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)


class BedrockAnthropicChatCompletions(BedrockClient):
    """Generate chat completions with Anthropic Claude models on Amazon Bedrock.

    Requests are built in the Anthropic Messages format
    (``"anthropic_version": "bedrock-2023-05-31"``) and sent through the
    ``invoke_model`` method of the client returned by
    ``self._get_bedrock_client()``.
    """
    # Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/bedrock-runtime_example_bedrock-runtime_InvokeModel_AnthropicClaude_section.html

    log: logging.Logger = logging.getLogger("BedrockAnthropicChatCompletions")

    def __init__(
        self,
        # NOTE: these os.getenv(...) defaults are evaluated once, when this
        # method is defined (module import time, after load_dotenv() has run),
        # not on every instantiation. Later changes to os.environ are not seen.
        aws_access_key: Optional[str] = os.getenv("AWS_ACCESS_KEY_ID"),
        aws_secret_key: Optional[str] = os.getenv("AWS_SECRET_ACCESS_KEY"),
        region_name: Optional[str] = os.getenv("AWS_REGION"),
        model_id: Optional[str] = "anthropic.claude-3-haiku-20240307-v1:0",
    ) -> None:
        """Initialize the BedrockAnthropicChatCompletions class.

        Args:
            aws_access_key: The AWS access key. Defaults to the value of
                AWS_ACCESS_KEY_ID captured at import time.
            aws_secret_key: The AWS secret key. Defaults to the value of
                AWS_SECRET_ACCESS_KEY captured at import time.
            region_name: The AWS region name. Defaults to the value of
                AWS_REGION captured at import time.
            model_id: The Bedrock model ID to invoke. Only Anthropic Claude
                models are supported, because the request body uses the
                Anthropic Messages format.

        Attributes set:
            model_id: The model ID passed in.
            bedrock_client: The client object returned by
                ``self._get_bedrock_client()``; used for ``invoke_model``.
        """
        super().__init__(
            aws_access_key=aws_access_key,
            aws_secret_key=aws_secret_key,
            region_name=region_name,
        )
        self.model_id = model_id
        # _get_bedrock_client() is not defined in this class; presumably it is
        # inherited from BedrockClient and returns a boto3 bedrock-runtime
        # client (TODO confirm).
        self.bedrock_client = self._get_bedrock_client()

    def predict(
        self,
        text: str,
        *,
        temperature: float = 0.00001,
        max_tokens: int = 512,
    ) -> str:
        """Generate a chat completion for a single user message.

        Args:
            text: The user message to send to the model.
            temperature: Sampling temperature sent in the request.
                Defaults to 0.00001.
            max_tokens: Maximum number of tokens the model may generate.
                Defaults to 512.

        Returns:
            str: The ``text`` field of the first content block in the model's
            response.

        Raises:
            SystemExit: With status 1 if the ``invoke_model`` call raises. The
                error is logged first.
        """
        # Format the request payload using the model's native structure.
        native_request = {
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": max_tokens,
            "temperature": temperature,
            "messages": [
                {
                    "role": "user",
                    "content": [{"type": "text", "text": text}],
                }
            ],
        }

        # invoke_model expects the body as a serialized JSON string.
        request = json.dumps(native_request)

        try:
            response = self.bedrock_client.invoke_model(
                modelId=self.model_id, body=request)
        except (ClientError, Exception) as err:
            # ClientError is listed for readability. Exception also covers
            # other failures (e.g. missing credentials or endpoint errors).
            # Log the configured model ID: self.model_id is the attribute set
            # in __init__.
            self.log.error(
                "ERROR: Can't invoke '%s'. Reason: %s", self.model_id, err)
            # Equivalent to exit(1) without relying on the site-provided
            # builtin; the original error is chained for debugging.
            # NOTE(review): terminating the process from library code is
            # heavy-handed; consider raising a domain error instead.
            raise SystemExit(1) from err

        # Decode the streaming response body into a dict.
        model_response = json.loads(response["body"].read())

        # Assumes at least one content block whose first entry is a text
        # block; an empty "content" list would raise IndexError here.
        return model_response["content"][0]["text"]


if __name__ == '__main__':
    # Running this file directly requires changing the BedrockClient import at
    # the top of the file to:
    #     from client import BedrockClient
    #
    # If you do not plan to use BedrockClient and its models, the boto3 and
    # botocore packages can be removed from the project. From a terminal:
    #   1. cd backend                    (run this from the backend directory)
    #   2. poetry remove boto3 botocore  (removes both packages from the project)

    # Any Anthropic Claude model ID available on Bedrock can be used here.
    demo_model_id = "anthropic.claude-3-haiku-20240307-v1:0"

    # Credentials and region come from the environment (.env is loaded at import).
    client_kwargs = {
        "aws_access_key": os.getenv("AWS_ACCESS_KEY_ID"),
        "aws_secret_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
        "region_name": os.getenv("AWS_REGION"),
        "model_id": demo_model_id,
    }
    demo_client = BedrockAnthropicChatCompletions(**client_kwargs)

    # Ask a sample question and show both the result type and the text.
    reply = demo_client.predict("What is the meaning of life?")

    print(type(reply))
    print(reply)